# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Dense head."""

from typing import Optional, List, Union

from mindspore import nn

from mindvision.common.utils.class_factory import ClassFactory, ModuleType


@ClassFactory.register(ModuleType.HEAD)
class DenseHead(nn.Cell):
    """
    DenseHead architecture: a fully connected classification head, optionally
    preceded by a dropout layer.

    Args:
        input_channel (int): The number of channels in the input space.
        num_classes (int): Number of classes.
        has_bias (bool): Specifies whether the dense layer uses a bias vector. Default: True.
        activation (Union[str, Cell]): Activation function applied to the output of the
            fully connected layer, e.g. 'ReLU'. Default: None.
        has_dropout (bool): Whether a dropout layer is applied before the dense layer.
            Default: False.
        keep_prob (float): The dropout keep rate, greater than 0 and less than or equal
            to 1. E.g. keep_prob=0.9 drops out 10% of input units. Only used when
            has_dropout is True. Default: 0.5.

    Returns:
        Tensor, output tensor.
    """

    def __init__(self,
                 input_channel: int,
                 num_classes: int,
                 has_bias: bool = True,
                 activation: Optional[Union[str, nn.Cell]] = None,
                 has_dropout: bool = False,
                 keep_prob: float = 0.5,
                 ) -> None:
        super(DenseHead, self).__init__()
        head: List[nn.Cell] = []

        # Dropout (when enabled) runs before the dense projection; the dense
        # layer itself is built once, regardless of the dropout setting.
        if has_dropout:
            head.append(nn.Dropout(keep_prob))
        head.append(
            nn.Dense(input_channel, num_classes, has_bias=has_bias, activation=activation)
        )

        self.classifier = nn.SequentialCell(head)

    def construct(self, x):
        """Run the classifier head over the input feature tensor `x`."""
        return self.classifier(x)
